{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "14efeb1e",
   "metadata": {},
   "source": [
    "## Deep Dive into CP + THD + AG + Striped>1 + SWA support for Transformer Engine JAX\n",
    "This feature was merged as part of [PR 2379](https://github.com/NVIDIA/TransformerEngine/pull/2379/) and was made available in Transformer Engine v2.11. This document addresses 3 fundamental questions about the design considerations and the implementation logic for this feature."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f31119f",
   "metadata": {},
   "source": [
    "### Question 1: Why choose Striped>1 ?\n",
    "\n",
    "Prior to the addition of this feature, Transformer Engine JAX attention already supported load balancing via a striping pattern, i.e., `stripe_size=1` for `CP + THD + P2P(Ring) + Striped + SWA`. However, this reordering technique does not lend itself well to an all-gathered (post-AG) pattern. The following example illustrates this distinction. For this example, `cp_size=4`, `num_segments=4`, `window_size=(8,0)`, and the pattern is for a single rank after striped reordering has been performed: \n",
    "\n",
    "#### I. Striped (`stripe_size=1`)\n",
    "- Such a staggered pattern is not supported by cuDNN\n",
    "- One possible way to express this with cuDNN support is by treating each `q` token as a segment, thereby producing 16 segments with varying `kv` token counts. However, this is very inefficient and does not scale well as max_seqlens increases\n",
    "\n",
    "```\n",
    "1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "1 1 1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "1 1 1 1 1 1 1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - 1 1 1 1 1 1 1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - 2 2 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 4 4 4 - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 4 4 4 4 4 4 4 - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 4 4 4 4 4 4 4 4 4 - - -\n",
    "```\n",
    "<figure align=\"left\">\n",
    "<figcaption> Figure 1: Post load balancing using stripe_size=1 and post AG attention pattern for a single cp rank </figcaption>\n",
    "</figure>\n",
    "\n",
    "\n",
    "#### II. Striped > 1 (`stripe_size > 1`)\n",
    "- This pattern is supported by cuDNN, with a suggested `stripe_size=128`\n",
    "- The mask type supported by `CP + THD + AG + Striped>1 + SWA` is `PADDING_CAUSAL_MASK`; however, to express the pattern below, each rank executes THD + SWA using `PADDING_BOTTOM_RIGHT_CAUSAL_MASK`\n",
    "- `max_num_segments_for_rank` needs to be estimated. The estimation formula used is: `max_seqlens // (stripe_size * cp_size) + max_num_segments`\n",
    "\n",
    "```\n",
    "1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "1 1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - 2 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 4 - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 4 4 - - - - - - - - - - - -\n",
    "```\n",
    "<figure align=\"left\">\n",
    "<figcaption> Figure 2: Post load balancing using stripe_size=4 and post AG attention pattern for a single cp rank </figcaption>\n",
    "</figure>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6eddfa7a",
   "metadata": {},
   "source": [
    "### Question 2: Why is there a need for separate helper functions for calculating seqlens and offsets ?\n",
    "\n",
    "The seqlens and offsets are calculated by the fused attn JAX primitives (both, CP and non-CP) so that they can be passed down to `fused_attn_arbitrary_seqlen_fwd_impl()` / `fused_attn_arbitrary_seqlen_bwd_impl()`, where it is translated before passing down to the cuDNN FE layer. The current (Transformer Engine v2.10) calculation of seqlens and offsets entails the CP primitive passing the sharded segment_ids, segment_pos, seq_lens, seq_offsets stuffed in a SequenceDescriptor object (a convenience class provided for packing these 4 tensors) to the `FusedAttnPrimitive`, which in turn calls `get_seqlens_and_offsets()` on the SequenceDescriptor object. \n",
    "\n",
    "If `get_seqlens_and_offsets()` receives a SequenceDescriptor object with seq_lens and seq_offsets populated and, segment_ids, segment_pos with size=0, it returns the seq_lens and seq_ofsets as it is (for e.g. `CP + BSHD + AG`). However, if `get_seqlens_and_offsets()` receives a SequenceDescriptor object with segment_ids and segment_pos populated and, seq_lens, seq_offsets with size=0, it first constructs a mask using the segment_ids and segment_pos and then extracts the seq_lens and seq_offsets from it and then returns it (for e.g. `CP + THD + P2P`).\n",
    "\n",
    "The problem with the current approach of calculating a mask followed by extracting the seq_lens and seq_offsets is that it is unable to express the patterns seen in `CP + THD + AG`. Below is one such example: \n",
    "\n",
    "```\n",
    "1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "1 1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - -\n",
    "```\n",
    "<figure align=\"left\">\n",
    "<figcaption> Figure 3: Example 1 for problem using mask path in get_seqlens_and_offsets() for attention pattern (post striping and AG) . </figcaption>\n",
    "</figure>\n",
    "\n",
    "Here, ideally, the two sections of the segment 3 should be split into two different segments (segment 3_1 formed using rows 9-12 and segment 3_2 formed using rows 13-16) as cuDNN does not support segment 3's entire staggered shape (as discussed earlier) , however, the mask route is unable to make this distinction, and it ends up treating it as one large segment thereby performing unnecessary computations of the padded regions in segment 3(rows 9-12 )\n",
    "\n",
    "In the below example, the mask route takes the `kv_seqlens` for segment 1 to be 6 and masks it using Bottom Right Causal Mask rather than taking `kv_seqlens` of 4 and masks it using Bottom Right Causal Mask, resulting in incorrect results\n",
    "\n",
    "```\n",
    "1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "1 1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "```\n",
    "<figure align=\"left\">\n",
    "<figcaption> Figure 4: Example 2 for problem using mask path in get_seqlens_and_offsets() for attention pattern (post striping and AG) </figcaption>\n",
    "</figure>\n",
    "\n",
    "The second case can be resolved in the mask path, but that would require adding CP specific details to the non-CP FusedAttn primitive which would contaminate it. Besides, resolving the first case would be even trickier with this approach. Due to it being incompatible with the design of FusedAttn primitive and inadequate to express the pattern needed for `CP + THD + AG` fully, separate helper functions were created which calculate the seqlens and seqoffsets, without creating a mask, hence also being O(N) space."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3cc4a12c",
   "metadata": {},
   "source": [
    "### Question 3: What is the implementation logic for the separate helper functions ?\n",
    "\n",
    "This section discusses the implementation logic for two of these four helper functions which serve as a reference, as the other two are using similar principles. Consider the test example in the code block, for which, `cp_size=4`, `stripe_size=4`, `max_seqlens=64`, `num_segments=2` and no SWA for simplicity. seg_1 has 8 valid tokens + 13 padded tokens and seg_2 has 31 valid tokens + 1 padded token. The 0 is used to explicitly show the padded region of seg_1 which is reordered, but for computation purposes it is equivalent to any of the `-` marked elements.\n",
    "\n",
    "```\n",
    "segment_ids_q_0_reordered = segment_ids_kv_0_reordered = jnp.array([[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]])\n",
    "\n",
    "segment_pos_q_0_reordered = segment_pos_kv_0_reordered = jnp.array([[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]])\n",
    "\n",
    "segment_ids_kv_0_seed12_ag_inv_reordered = jnp.array([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])\n",
    "\n",
    "segment_pos_kv_0_seed12_ag_inv_reordered= jnp.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])\n",
    "```\n",
    "\n",
    "```\n",
    "1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "1 1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - -\n",
    "- - - - - - - - - - - - - - - - - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - -\n",
    "```\n",
    "<figure align=\"left\">\n",
    "<figcaption> Figure 5: An example of post striped reordering and AG attention pattern on a single rank.</figcaption>\n",
    "</figure>\n",
    "\n",
    "#### I. Implementation logic for q_seqlens_for_striped_for_rank()\n",
    "**What is the objective/logic ?**\n",
    "- Create a new set of segment ids for this rank such that:\n",
    "    - It gets rid of padding information as it does not contribute to the seqlens calculation\n",
    "    - It has the ability to identify ”new segments” being created from the same original segment\n",
    "- Use this new set of segment ids to calculate the seqlens\n",
    "\n",
    "**Example walkthrough**\n",
    "1. Calculate the non-zero indices (where seg ids !=0)\n",
    "2. Calculate the valid seg ids and valid seg pos (i.e. index into seg ids and seg pos using the non-zero indices)\n",
    "    - `valid_segment_ids=[[1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0]]`\n",
    "    - `valid_segment_pos=[[0, 1, 2, 3, 11, 12, 13, 14, 27, 28, 29, 30, 0, 0, 0, 0]]`\n",
    "    - Ignore the 0s at the end of the two arrays as they are just for padding to a static length\n",
    "3. Find locations where a q segment change/break happens. A segment change happens when: \n",
    "    - there is a change in valid_segment_ids OR \n",
    "    - `valid_segment_pos[i+1] != valid_segment_pos[i]`\n",
    "    - `segment_changes=[[True, False, False, False, True, False, False, False, True, False, False, False, True, True, True, True]]`\n",
    "4. Perform a cumulative sum on the segment changes: \n",
    "    - `new_segment_ids=[[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 5, 6, 7]]`\n",
    "5. Filter out the valid indices only and pad at the end with 0s upto static length (these are our “new” segment indices without padding)\n",
    "    - `new_segment_ids_filtered=[[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 0, 0, 0, 0]]`\n",
    "    - Notice here that the large chunk of 8 q token rows (rows 9-16 in Fig 5) gets broken down into 2 \"new\" segments of 4 q token rows each,\n",
    "    which is a pattern that cuDNN supports and it ensures that wasted computation for padded regions of rows 9-12 is not performed, which was the\n",
    "    case in Fig 3\n",
    "6. Perform a bin count and pad with -1s upto `max_num_segments_per_seq_for_rank`\n",
    "    - `seqlens_with_neg1_padding[[ 4, 4, 4, -1, -1, -1, -1]]`\n",
    "\n",
    "\n",
    "#### II. Implementation logic for kv_seqoffsets_for_striped_for_rank()\n",
    "**What is the objective/logic ?**\n",
    "- Get the original segment ids for those locations where segment changes happen (arr1)\n",
    "    - Each segment has a known kv offset, hence if we know which original segment id a \"new\" segment is associated with we can find it's kv offset\n",
    "    - So, for e.g., in Fig 5, all valid tokens of seg_3 have the same kv offset, so even if this gets split into a 2 \"new\" segments, we can procure the offset for both using a mapping of original seg-ids to kv offset \n",
    "- Get the segment ids for those locations where segment changes happen in the AG tensor (arr2)\n",
    "    - This is used to create a kind of mapping between original seg-ids to kv offset\n",
    "- Pick values from arr2 mapping for the \"new\" segment ids collected in arr1\n",
    "\n",
    "**Example walkthrough**\n",
    "1. Find locations where a kv segment pos change/break happens and mask out zero seg ids. A segment change happens when: \n",
    "    - `kv_segment_pos[i+1] != kv_segment_pos[i]`\n",
    "    - `segment_changes_masked=[[ True, False, False, False, False, False, False, False, True, False, False, False, True, False, False, False]]`\n",
    "2. Get the indices where the segment changes happen and the segment ids associated with them:\n",
    "    - `segment_changes_indices=[[0, 8, 12, -1, -1, -1, -1, -1, -1]]`\n",
    "    - `[[1, 2, 2, -1, -1, -1, -1, -1, -1]]`\n",
    "3. Find the segment pos changes/break for the AG seg pos and mask out zero seg ids\n",
    "    - `segment_changes_masked_ag=[[True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]]`\n",
    "4. Get indices where the segment changes happen for the AG seg pos (this works as a mapping between segment ids and kv offsets)\n",
    "    - `segment_changes_ag_indices=[[0, 21, -1, -1, -1, -1, -1, -1, -1]]`\n",
    "5. Get the seq offsets by indexing into segment_changes_ag_indices using segment_changes_indices :\n",
    "    - `kv_seq_offsets[[0, 21, 21, -1, -1, -1, -1, -1, -1]]`\n",
    "\n",
    "The implementation details for `q_seqoffsets_for_striped_for_rank()` and `kv_seqlens_for_striped_for_rank()` can be found in [PR 2379](https://github.com/NVIDIA/TransformerEngine/pull/2379/)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
